@everywhere using Distributions, Optim, Plots, ProgressMeter, Weave
@everywhere using CSV, Tables, StatsBase, Distributions
@everywhere using Distributed, SharedArrays
@everywhere using ParallelUtilities

S1 = 100 # number of simulation replications (500 was noted as the full-run value)
S2 = 100 # number of bootstrap replications per simulation replication

## Bandwidths for the Gaussian kernel on the change in the continuous regressor.
## `n_sigma` is defined on every process because code executed on the workers uses it.
@everywhere sigma_vals =  [0.01 0.05 0.1 .25 0.5 1. 2. 5. 10.]
@everywhere n_sigma = length(sigma_vals)

## Load data, set sample size
# Change this path to the folder where the data files were placed (see the README).
@everywhere application_data = CSV.read("folder/data/sample.csv", Tables.matrix)
@everywhere dU = Logistic(0,1) # distribution of the period-specific errors U1-U3

## True parameter values of the simulation design
#                 child  married unemp  retired other
@everywhere β1 = [-0.030 0.139   -0.188 -0.132  -0.370]
@everywhere β2 = [0.049] # coef on log income
@everywhere ρ = [0.733]    # autoreg on >=3
@everywhere γ2 = [-3.275]  # cut points; γ3 is normalised to 0
@everywhere γ3 = [0.0]
@everywhere γ4 = [3.326]
@everywhere Kd = 5 # 5 discrete regressors

## The full sample has 260601 individuals; only the first N rows are used.
@everywhere N =  500 # e.g. 100000 for a larger run
@show N
@everywhere application_data = application_data[1:N,:]

## Simulate one Monte Carlo data set: the observed initial outcome Y0 (period 0) and the
## covariates of periods 1-3 come from the empirical application; the outcomes Y1-Y3 are
## drawn from the dynamic ordered model using the true parameter values.
@everywhere function generate_data(; resample_covariates::Bool = false)
    # By default the covariates are held fixed across replications. Pass
    # `resample_covariates = true` to draw the rows with replacement instead.
    n_rows = size(application_data, 1)
    ad_s = resample_covariates ? application_data[sample(1:n_rows, n_rows), :] :
                                 application_data

    Y0 = ad_s[:, 2]
    D0 = ifelse.(Y0 .>= 3., 1, 0)  # binary state: 1 when the ordered outcome is >= 3

    # Per period t: continuous regressor Xc_t (income) and the five discrete
    # regressors Xd_t (child, married, unemp, retired, other).
    Xc_1 = ad_s[:, 3]
    Xd_1 = ad_s[:, [6; 9; 12; 15; 18]]

    Xc_2 = ad_s[:, 4]
    Xd_2 = ad_s[:, [7; 10; 13; 16; 19]]

    Xc_3 = ad_s[:, 5]
    Xd_3 = ad_s[:, [8; 11; 14; 17; 20]]

    n = length(Y0)

    # Individual effect A (the same draw enters every period) and period-specific
    # errors from dU. The draw order is kept fixed.
    A = rand(Normal(0, 1), n)
    U1 = rand(dU, n)
    U2 = rand(dU, n)
    U3 = rand(dU, n)

    # Ordered outcome = 1 + number of cut points (γ2, γ3, γ4) exceeded by the latent
    # index. Binary state = 1{latent > γ3}, which feeds the next period's index.
    function ordered_outcome(y_star)
        y_ord = ones(n)
        y_ord = y_ord .+ ifelse.(y_star .> γ2, 1, 0)
        y_ord = y_ord .+ ifelse.(y_star .> γ3, 1, 0)
        y_ord = y_ord .+ ifelse.(y_star .> γ4, 1, 0)
        d_state = ifelse.(y_star .> γ3, 1, 0)
        return y_ord, d_state
    end

    # Each period's latent index uses that period's own regressors Xd_t and Xc_t. The
    # estimator identifies β2 from Xc_2 - Xc_1 and kernel-weights on Xc_3 - Xc_2, so the
    # simulated outcomes must respond to those same covariates.
    Y1_star = Xd_1*β1' + Xc_1.*β2 + U1 + A + D0.*ρ
    Y1, D1 = ordered_outcome(Y1_star)

    Y2_star = Xd_2*β1' + Xc_2.*β2 + U2 + A + D1.*ρ
    Y2, D2 = ordered_outcome(Y2_star)

    Y3_star = Xd_3*β1' + Xc_3.*β2 + U3 + A + D2.*ρ
    Y3, D3 = ordered_outcome(Y3_star)

    ## Glue the discrete and continuous regressors
    X_1 = [Xd_1 Xc_1]
    X_2 = [Xd_2 Xc_2]
    X_3 = [Xd_3 Xc_3]

    return Y0, Y1, Y2, Y3, D0, D1, D2, D3, X_1, X_2, X_3, Xd_2, Xd_3, Xc_2, Xc_3

end

_ = generate_data()  # NOTE(review): presumably a warm-up/smoke-test call; result is discarded

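##
## This function implements the MRV estimator. It:
## - takes in the 15 arrays returned by `generate_data` (Y0 through Xc_3, same order), and
## - returns a tuple (b2, r, g2, g4) of vectors of estimated coefficients
##   (coefficient on the continuous regressor, autoregressive coefficient, cut points
##   γ2 and γ4), with one entry per bandwidth in `sigma_vals`.
##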

@everywhere function theta_hat_mrv(Y0, Y1, Y2, Y3, D0, D1, D2, D3, X_1, X_2, X_3, Xd_2, Xd_3, Xc_2, Xc_3)
    # MRV estimator, computed once for every bandwidth in `sigma_vals`.
    #
    # Arguments: the 15 arrays returned by `generate_data`, in the same order.
    # Y0, Y1, D2, D3 and X_3 are accepted for interface compatibility but are not used.
    #
    # Returns (b2, r, g2, g4): vectors with one entry per bandwidth holding the estimated
    # coefficient on the continuous regressor, the autoregressive coefficient, and the
    # cut points γ2 and γ4 (γ3 is normalised to 0).

    ## Keep only observations whose discrete regressors are unchanged from period 2 to 3
    keep = vec(all(Xd_2 .== Xd_3, dims = 2))

    DX_f = (X_2 - X_1)[keep, :]      # change in all regressors, period 1 -> 2
    dXc_f = Xc_3[keep] - Xc_2[keep]  # change in the continuous regressor, period 2 -> 3
    D0_f = D0[keep]
    D1_f = D1[keep]
    Y2_f = Y2[keep]
    Y3_f = Y3[keep]

    ## Parameter-free pieces of the likelihood for each binarisation (j, l) of the ordered
    ## outcome: regressor matrix Zjl and switcher indicator Smrv. They are built once
    ## here instead of on every likelihood evaluation (same j-outer, l-inner order).
    cells = map([(j, l) for j in 2:3 for l in 3:4]) do (j, l)
        D2j = ifelse.(Y2_f .>= j, 1, 0)
        D2l = ifelse.(Y2_f .>= l, 1, 0)
        D2jl = ifelse.(D1_f .== 1, D2j, D2l)

        D3j = ifelse.(Y3_f .>= j, 1, 0)
        D3l = ifelse.(Y3_f .>= l, 1, 0)
        D3jl = ifelse.(D1_f .== 1, D3l, D3j)

        Smrv = ifelse.(D1_f .== D2jl, 0, 1)
        Zjl = hcat(-DX_f, D0_f .- D3jl, D3jl, 1.0 .- D3jl)
        (j, l, Zjl, Smrv)
    end

    ## Kernel-weighted negative log-likelihood at `brc` = (bb1[1:Kd], bb2, r, γ2, γ4).
    ## `kernel_w` holds the kernel weights on dXc_f for the current bandwidth.
    function ll_mrv(brc, kernel_w)
        bb1 = brc[1:Kd]
        bb2 = brc[Kd+1]
        r = brc[Kd+2]
        c = [0.0 brc[Kd+3] 0.0 brc[Kd+4]]  # c[1] unused, c[2] = γ2, c[3] = γ3 = 0, c[4] = γ4

        rval = zero(eltype(brc))
        for (j, l, Zjl, Smrv) in cells
            index = Zjl * vcat(bb1, bb2, r, c[j], c[l])
            # Use logcdf/logccdf rather than log.(cdf)/log.(1 .- cdf). The latter hit -Inf
            # once the logistic cdf rounds to 0 or 1, and 0 * -Inf = NaN then poisons the
            # whole sum, even for observations with zero weight.
            loglik = D1_f .* logcdf.(dU, index) .+ (1 .- D1_f) .* logccdf.(dU, index)
            rval -= sum(kernel_w .* Smrv .* loglik)
        end
        return rval
    end

    ## Estimate at every bandwidth, starting the optimiser at the true parameter values
    n_bw = length(sigma_vals)
    b2_mrv = zeros(n_bw)
    r_mrv = zeros(n_bw)
    g2_mrv = zeros(n_bw)
    g4_mrv = zeros(n_bw)
    gaus_kernel = Normal(0, 1)
    theta0 = vcat(vec(β1), β2, ρ, γ2, γ4)
    for (sig_index, sigma) in enumerate(sigma_vals)
        kernel_w = pdf.(gaus_kernel, dXc_f ./ sigma)
        func_mrv = TwiceDifferentiable(brc -> ll_mrv(brc, kernel_w), ones(Kd + 4);
                                       autodiff = :forward)
        theta = Optim.minimizer(optimize(func_mrv, copy(theta0)))

        ## extract the coefficient estimates
        b2_mrv[sig_index] = theta[Kd+1]
        r_mrv[sig_index] = theta[Kd+2]
        g2_mrv[sig_index] = theta[Kd+3]
        g4_mrv[sig_index] = theta[Kd+4]
    end

    return b2_mrv, r_mrv, g2_mrv, g4_mrv

end

## Top-level setup: normal critical value and the arrays that collect the results.
d = Normal()
za = quantile(d, 0.975)  # critical value for a two-sided 95% interval

# SharedArrays, so that the distributed loop can write per-replication results into
# width_S1 and is_it_in. coverage and width are filled on the master at the end.
coverage = SharedArray(fill(1.0, n_sigma))       # per-bandwidth coverage rate
width = SharedArray(fill(0.0, n_sigma))          # per-bandwidth mean CI width
width_S1 = SharedArray(fill(0.0, S1, n_sigma))   # CI width, per replication and bandwidth
is_it_in = SharedArray(fill(1.0, S1, n_sigma))   # 1 if the CI covers β2, else 0

## Monte Carlo loop, distributed over the workers. `@sync` makes the master wait until
## every replication has written its results into the SharedArrays.
@sync @distributed for s in 1:S1
    @show s

    ## Generate the data and compute the point estimates for every bandwidth
    data = generate_data()
    b2, r, g2, g4 = theta_hat_mrv(data...)

    ## Bootstrap: resample individuals with replacement and re-estimate.
    ## Only the coefficient on the continuous regressor (b2) is stored.
    n = length(first(data))  # number of individuals (length of Y0)
    b_bs = zeros(S2, n_sigma)
    for j in 1:S2
        s_index = sample(1:n, n)
        # `selectdim(x, 1, idx)` takes the same rows of matrices and entries of vectors
        data_j = map(x -> copy(selectdim(x, 1, s_index)), data)
        b_bs[j, :] = first(theta_hat_mrv(data_j...))
    end

    ## Normal-approximation 95% CI using the bootstrap standard error
    for ii in 1:n_sigma
        b_se = std(b_bs[:, ii])
        lower = b2[ii] - za * b_se
        upper = b2[ii] + za * b_se
        # Written as a positive "inside" test so that a NaN estimate or standard error
        # (e.g. from a failed optimisation) counts as non-coverage, not as coverage.
        is_it_in[s, ii] = (lower <= β2[1] <= upper) ? 1 : 0
        width_S1[s, ii] = 2 * za * b_se
    end
end

## Average the per-replication results over the S1 simulations, one bandwidth at a time
coverage .= [mean(is_it_in[:, k]) for k in 1:n_sigma]
width .= [mean(width_S1[:, k]) for k in 1:n_sigma]

@show coverage
@show width

np, nw = nprocs(), nworkers()
println("number of processes : $np")
println("number of workers : $nw")